{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "run_control": {
     "marked": true
    }
   },
   "source": [
    "## Description：\n",
    "这里是通过一个简单的小案例来看看如何通过掉包的方式使用FM模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-26T00:46:51.274471Z",
     "start_time": "2020-09-26T00:46:50.076694Z"
    }
   },
   "outputs": [],
   "source": [
    "# 导入包\n",
    "from pyfm import pylibfm\n",
    "from sklearn.feature_extraction import DictVectorizer\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "使用这个类最简单的方式就是把数据存成字典的形式， 然后用DictVectorizer进行one-hot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-26T00:46:52.927599Z",
     "start_time": "2020-09-26T00:46:52.922647Z"
    }
   },
   "outputs": [],
   "source": [
    "train = [\n",
    "    {'user': '1', 'item': '5', 'age': 19},\n",
    "    {'user': '2', 'item': '43', 'age': 33},\n",
    "    {'user': '3', 'item': '20', 'age': 55},\n",
    "    {'user': '4', 'item': '10', 'age': 20}\n",
    "]\n",
    "v = DictVectorizer()\n",
    "X = v.fit_transform(train)      # 本身被压缩了"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-26T00:46:54.612205Z",
     "start_time": "2020-09-26T00:46:54.603196Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[19.,  0.,  0.,  0.,  1.,  1.,  0.,  0.,  0.],\n",
       "       [33.,  0.,  0.,  1.,  0.,  0.,  1.,  0.,  0.],\n",
       "       [55.,  0.,  1.,  0.,  0.,  0.,  0.,  1.,  0.],\n",
       "       [20.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  1.]])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X.toarray()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-25T11:33:20.254650Z",
     "start_time": "2020-09-25T11:33:20.250660Z"
    }
   },
   "outputs": [],
   "source": [
    "y = np.repeat(1, X.shape[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-25T11:33:37.234732Z",
     "start_time": "2020-09-25T11:33:37.227749Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Creating validation dataset of 0.01 of training for adaptive regularization\n",
      "-- Epoch 1\n",
      "Training log loss: 0.38084\n"
     ]
    }
   ],
   "source": [
    "fm = pylibfm.FM()\n",
    "fm.fit(X, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-25T11:34:25.719204Z",
     "start_time": "2020-09-25T11:34:25.714261Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.99212296])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 测试集\n",
    "test = v.transform({'user': \"1\", 'item': \"10\", 'age': 24})\n",
    "fm.predict(test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 预测Alice的评分\n",
    "这个还是用前面协同过滤和矩阵分解的那个例子玩一下"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-26T00:46:57.743357Z",
     "start_time": "2020-09-26T00:46:57.738370Z"
    }
   },
   "outputs": [],
   "source": [
    "def loadData():\n",
    "    rating_data={1: {'A': 5, 'B': 3, 'C': 4, 'D': 4},\n",
    "           2: {'A': 3, 'B': 1, 'C': 2, 'D': 3, 'E': 3},\n",
    "           3: {'A': 4, 'B': 3, 'C': 4, 'D': 3, 'E': 5},\n",
    "           4: {'A': 3, 'B': 3, 'C': 1, 'D': 5, 'E': 4},\n",
    "           5: {'A': 1, 'B': 5, 'C': 5, 'D': 2, 'E': 1}\n",
    "          }\n",
    "    return rating_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-26T00:46:58.957723Z",
     "start_time": "2020-09-26T00:46:58.954744Z"
    }
   },
   "outputs": [],
   "source": [
    "rating_data = loadData()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-26T00:46:59.658705Z",
     "start_time": "2020-09-26T00:46:59.650708Z"
    }
   },
   "outputs": [],
   "source": [
    "df = pd.DataFrame(rating_data).T\n",
    "df = df.stack().reset_index()\n",
    "df.columns = ['user', 'item', 'rating']\n",
    "df['user'] = df['user'].astype('str')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-25T12:53:48.168227Z",
     "start_time": "2020-09-25T12:53:48.164188Z"
    }
   },
   "outputs": [],
   "source": [
    "item_map = {item: str(idx) for idx, item in enumerate(set(df['item']))}\n",
    "df['item'] = df['item'].map(item_map)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-26T00:51:49.479736Z",
     "start_time": "2020-09-26T00:51:49.474749Z"
    }
   },
   "outputs": [],
   "source": [
    "train_data = df[['user', 'item']]\n",
    "y = df['rating']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-26T00:51:50.246102Z",
     "start_time": "2020-09-26T00:51:50.242113Z"
    }
   },
   "outputs": [],
   "source": [
    "one = OneHotEncoder()\n",
    "x = one.fit_transform(train_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-26T00:52:06.805968Z",
     "start_time": "2020-09-26T00:52:06.801978Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<24x10 sparse matrix of type '<class 'numpy.float64'>'\n",
       "\twith 48 stored elements in Compressed Sparse Row format>"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-25T12:54:09.543543Z",
     "start_time": "2020-09-25T12:54:09.539553Z"
    }
   },
   "outputs": [],
   "source": [
    "# 建立模型\n",
    "fm = pylibfm.FM(num_factors=10, num_iter=100, verbose=True, task='regression', initial_learning_rate=0.001, learning_rate_schedule='optimal')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-25T12:54:11.674099Z",
     "start_time": "2020-09-25T12:54:11.588094Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Creating validation dataset of 0.01 of training for adaptive regularization\n",
      "-- Epoch 1\n",
      "Training MSE: 3.36957\n",
      "-- Epoch 2\n",
      "Training MSE: 3.36957\n",
      "-- Epoch 3\n",
      "Training MSE: 3.36957\n",
      "-- Epoch 4\n",
      "Training MSE: 3.36957\n",
      "-- Epoch 5\n",
      "Training MSE: 3.36957\n",
      "-- Epoch 6\n",
      "Training MSE: 3.36957\n",
      "-- Epoch 7\n",
      "Training MSE: 3.36824\n",
      "-- Epoch 8\n",
      "Training MSE: 3.18354\n",
      "-- Epoch 9\n",
      "Training MSE: 2.90356\n",
      "-- Epoch 10\n",
      "Training MSE: 2.65204\n",
      "-- Epoch 11\n",
      "Training MSE: 2.43094\n",
      "-- Epoch 12\n",
      "Training MSE: 2.23572\n",
      "-- Epoch 13\n",
      "Training MSE: 2.06420\n",
      "-- Epoch 14\n",
      "Training MSE: 1.91343\n",
      "-- Epoch 15\n",
      "Training MSE: 1.78005\n",
      "-- Epoch 16\n",
      "Training MSE: 1.66318\n",
      "-- Epoch 17\n",
      "Training MSE: 1.56070\n",
      "-- Epoch 18\n",
      "Training MSE: 1.46989\n",
      "-- Epoch 19\n",
      "Training MSE: 1.38988\n",
      "-- Epoch 20\n",
      "Training MSE: 1.31918\n",
      "-- Epoch 21\n",
      "Training MSE: 1.25672\n",
      "-- Epoch 22\n",
      "Training MSE: 1.20164\n",
      "-- Epoch 23\n",
      "Training MSE: 1.15306\n",
      "-- Epoch 24\n",
      "Training MSE: 1.10965\n",
      "-- Epoch 25\n",
      "Training MSE: 1.07162\n",
      "-- Epoch 26\n",
      "Training MSE: 1.03799\n",
      "-- Epoch 27\n",
      "Training MSE: 1.00777\n",
      "-- Epoch 28\n",
      "Training MSE: 0.98110\n",
      "-- Epoch 29\n",
      "Training MSE: 0.95754\n",
      "-- Epoch 30\n",
      "Training MSE: 0.93635\n",
      "-- Epoch 31\n",
      "Training MSE: 0.91735\n",
      "-- Epoch 32\n",
      "Training MSE: 0.90037\n",
      "-- Epoch 33\n",
      "Training MSE: 0.88507\n",
      "-- Epoch 34\n",
      "Training MSE: 0.87155\n",
      "-- Epoch 35\n",
      "Training MSE: 0.85941\n",
      "-- Epoch 36\n",
      "Training MSE: 0.84821\n",
      "-- Epoch 37\n",
      "Training MSE: 0.83823\n",
      "-- Epoch 38\n",
      "Training MSE: 0.82916\n",
      "-- Epoch 39\n",
      "Training MSE: 0.82078\n",
      "-- Epoch 40\n",
      "Training MSE: 0.81315\n",
      "-- Epoch 41\n",
      "Training MSE: 0.80625\n",
      "-- Epoch 42\n",
      "Training MSE: 0.79974\n",
      "-- Epoch 43\n",
      "Training MSE: 0.79388\n",
      "-- Epoch 44\n",
      "Training MSE: 0.78841\n",
      "-- Epoch 45\n",
      "Training MSE: 0.78327\n",
      "-- Epoch 46\n",
      "Training MSE: 0.77852\n",
      "-- Epoch 47\n",
      "Training MSE: 0.77408\n",
      "-- Epoch 48\n",
      "Training MSE: 0.76979\n",
      "-- Epoch 49\n",
      "Training MSE: 0.76581\n",
      "-- Epoch 50\n",
      "Training MSE: 0.76203\n",
      "-- Epoch 51\n",
      "Training MSE: 0.75832\n",
      "-- Epoch 52\n",
      "Training MSE: 0.75492\n",
      "-- Epoch 53\n",
      "Training MSE: 0.75166\n",
      "-- Epoch 54\n",
      "Training MSE: 0.74844\n",
      "-- Epoch 55\n",
      "Training MSE: 0.74539\n",
      "-- Epoch 56\n",
      "Training MSE: 0.74239\n",
      "-- Epoch 57\n",
      "Training MSE: 0.73950\n",
      "-- Epoch 58\n",
      "Training MSE: 0.73667\n",
      "-- Epoch 59\n",
      "Training MSE: 0.73394\n",
      "-- Epoch 60\n",
      "Training MSE: 0.73115\n",
      "-- Epoch 61\n",
      "Training MSE: 0.72856\n",
      "-- Epoch 62\n",
      "Training MSE: 0.72602\n",
      "-- Epoch 63\n",
      "Training MSE: 0.72342\n",
      "-- Epoch 64\n",
      "Training MSE: 0.72092\n",
      "-- Epoch 65\n",
      "Training MSE: 0.71846\n",
      "-- Epoch 66\n",
      "Training MSE: 0.71599\n",
      "-- Epoch 67\n",
      "Training MSE: 0.71357\n",
      "-- Epoch 68\n",
      "Training MSE: 0.71116\n",
      "-- Epoch 69\n",
      "Training MSE: 0.70875\n",
      "-- Epoch 70\n",
      "Training MSE: 0.70641\n",
      "-- Epoch 71\n",
      "Training MSE: 0.70408\n",
      "-- Epoch 72\n",
      "Training MSE: 0.70169\n",
      "-- Epoch 73\n",
      "Training MSE: 0.69938\n",
      "-- Epoch 74\n",
      "Training MSE: 0.69707\n",
      "-- Epoch 75\n",
      "Training MSE: 0.69475\n",
      "-- Epoch 76\n",
      "Training MSE: 0.69244\n",
      "-- Epoch 77\n",
      "Training MSE: 0.69016\n",
      "-- Epoch 78\n",
      "Training MSE: 0.68780\n",
      "-- Epoch 79\n",
      "Training MSE: 0.68556\n",
      "-- Epoch 80\n",
      "Training MSE: 0.68327\n",
      "-- Epoch 81\n",
      "Training MSE: 0.68097\n",
      "-- Epoch 82\n",
      "Training MSE: 0.67868\n",
      "-- Epoch 83\n",
      "Training MSE: 0.67640\n",
      "-- Epoch 84\n",
      "Training MSE: 0.67407\n",
      "-- Epoch 85\n",
      "Training MSE: 0.67179\n",
      "-- Epoch 86\n",
      "Training MSE: 0.66950\n",
      "-- Epoch 87\n",
      "Training MSE: 0.66715\n",
      "-- Epoch 88\n",
      "Training MSE: 0.66489\n",
      "-- Epoch 89\n",
      "Training MSE: 0.66258\n",
      "-- Epoch 90\n",
      "Training MSE: 0.66024\n",
      "-- Epoch 91\n",
      "Training MSE: 0.65795\n",
      "-- Epoch 92\n",
      "Training MSE: 0.65560\n",
      "-- Epoch 93\n",
      "Training MSE: 0.65328\n",
      "-- Epoch 94\n",
      "Training MSE: 0.65093\n",
      "-- Epoch 95\n",
      "Training MSE: 0.64859\n",
      "-- Epoch 96\n",
      "Training MSE: 0.64617\n",
      "-- Epoch 97\n",
      "Training MSE: 0.64383\n",
      "-- Epoch 98\n",
      "Training MSE: 0.64149\n",
      "-- Epoch 99\n",
      "Training MSE: 0.63907\n",
      "-- Epoch 100\n",
      "Training MSE: 0.63669\n"
     ]
    }
   ],
   "source": [
    "fm.fit(x, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-25T12:54:53.105145Z",
     "start_time": "2020-09-25T12:54:53.100160Z"
    }
   },
   "outputs": [],
   "source": [
    "# 测试集\n",
    "test = {'user': '1', 'item': '4'}\n",
    "x_test = one.transform(pd.DataFrame(test, index=[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-25T12:55:41.392583Z",
     "start_time": "2020-09-25T12:55:41.388596Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "FM的预测评分:3.513755892491899\n"
     ]
    }
   ],
   "source": [
    "pred_rating = fm.predict(x_test)\n",
    "print('FM的预测评分:{}'.format(pred_rating[0]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
